Week 9: Inference optimization

Applied Generative AI for AI Developers

Amit Arora

Agenda

  • Introduction to Inference Optimization
  • Quantization Techniques
    • Post-Training Quantization
    • Quantization-Aware Training
    • Weight-Only Quantization
    • Activation-Aware Quantization
  • Model Distillation
  • Speculative Decoding
  • KV Cache Optimization
  • Flash Attention
  • Pruning
  • Sparse Inference
  • Tensor Parallelism & Sharding
  • Continuous Batching
  • References & Further Reading

Introduction to Inference Optimization

  • LLMs are computationally expensive
    • GPT-3 (175B params) → GPT-4 (1.8T params estimate)
  • Inference optimization critical for:
    • Reducing latency
    • Decreasing memory footprint
    • Lowering deployment costs
    • Enabling edge deployment

Key metrics: Latency, Throughput, Memory, Cost
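As a feel for the memory metric: weight storage alone scales linearly with parameter count and bits per parameter. A quick back-of-envelope calculator (the model sizes are illustrative; real deployments also need memory for activations and the KV cache):

```python
# Back-of-envelope weight memory at different precisions.

def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """GB (1e9 bytes) needed to store the weights alone."""
    return num_params * bits_per_param / 8 / 1e9

for precision, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    print(f"7B model @ {precision}: {weight_memory_gb(7e9, bits):5.1f} GB")
# FP32: 28.0 GB, FP16: 14.0 GB, INT8: 7.0 GB, INT4: 3.5 GB
```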

Post-Training Quantization (PTQ)

  • Applied after model training is complete
  • Reduces precision of weights/activations
    • FP32 → FP16 / BF16 / INT8 / INT4
  • No retraining required
  • Minimal accuracy loss with proper calibration
  • Example: NVIDIA TensorRT, ONNX Runtime

Trade-offs: Quick to implement but may introduce accuracy degradation

# PyTorch example
import torch

# Original model (FP32)
model_fp32 = load_model()

# Quantize to INT8
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},
    dtype=torch.qint8
)

Quantization-Aware Training (QAT)

  • Simulates quantization effects during training
  • Model learns to be robust to quantization noise
  • Typically better performance than PTQ
  • Requires retraining the model
  • Frameworks: TensorFlow, PyTorch, HuggingFace Optimum

Key advantage: Better quality-performance trade-off

# PyTorch QAT example (eager-mode API)
import torch

# Prepare model for QAT: a qconfig must be set before prepare_qat
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
model_qat = torch.quantization.prepare_qat(
    model, 
    inplace=False
)

# Train with quantization in the loop
train_loop(model_qat, train_data)

# Convert to quantized model
model_quantized = torch.quantization.convert(
    model_qat, 
    inplace=False
)

Weight-Only Quantization

  • Quantizes only the model weights, not activations
  • Popular for LLMs (GPTQ, GGML, AWQ, LLM.int8())
  • Less intrusive than full quantization
  • Typical formats: INT8, INT4, or mixed precision
  • Significant memory reduction with minimal quality loss
  • Enables running models on consumer hardware

Example: LLaMA 7B from ~14GB (FP16) to ~4GB (INT4)

# Using GPTQ via HuggingFace Transformers
# (requires the auto-gptq / optimum packages)
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
quant_config = GPTQConfig(
    bits=4,
    group_size=128,
    dataset="c4",        # calibration data
    tokenizer=tokenizer
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    quantization_config=quant_config
)

Activation-Aware Quantization

  • Considers activation distributions when quantizing weights
  • Balances weight precision based on activation sensitivity
  • Examples: Activation-aware Weight Quantization (AWQ)
  • Optimizes quantization per channel/tensor
  • Preserves performance on critical network paths

Key insight: Not all weights equally impact model performance

AWQ Source: AWQ paper

# Conceptual AWQ process
# 1. Identify activation-sensitive weights
# 2. Apply higher precision to sensitive weights
# 3. Use lower precision for less sensitive weights
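The steps above can be sketched as follows. This is a simplified illustration of the AWQ idea (scale weight columns by activation magnitude before rounding), not the paper's full algorithm; the `alpha` exponent and the symmetric per-tensor quantizer are simplifying assumptions:

```python
import torch

def awq_style_quantize(weight, calib_acts, n_bits=4, alpha=0.5):
    """Scale weight columns by activation magnitude before rounding,
    so channels that see large activations keep more precision.
    weight: (out_features, in_features); calib_acts: (n_samples, in_features)."""
    # 1. Per-input-channel activation importance from calibration data
    act_scale = calib_acts.abs().mean(dim=0).clamp(min=1e-5)
    s = act_scale ** alpha                  # smoothing exponent (assumption)
    # 2. Emphasize activation-salient columns before quantizing
    w_scaled = weight * s
    qmax = 2 ** (n_bits - 1) - 1
    step = (w_scaled.abs().max() / qmax).clamp(min=1e-8)
    w_q = torch.clamp(torch.round(w_scaled / step), -qmax - 1, qmax)
    # 3. Dequantize and fold the column scaling back out
    return w_q * step / s
```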

Model Distillation

  • Transfers knowledge from larger teacher to smaller student
  • Student mimics teacher’s behavior, not architecture
  • Types:
    • Response-based (output matching)
    • Feature-based (intermediate layers)
    • Relation-based (attention maps)
  • Examples: DistilBERT, TinyBERT, MiniLM

Results: 40-60% size reduction with 90-95% performance

# Distillation loss example
import torch.nn.functional as F

def distillation_loss(student_logits, 
                      teacher_logits, 
                      temperature=2.0):
    """
    KL divergence between teacher and 
    student softened distributions
    """
    soft_teacher = F.softmax(
        teacher_logits / temperature, 
        dim=-1
    )
    # kl_div expects log-probabilities for the student argument
    log_soft_student = F.log_softmax(
        student_logits / temperature, 
        dim=-1
    )
    return F.kl_div(
        log_soft_student,
        soft_teacher,
        reduction='batchmean'
    ) * (temperature ** 2)

Speculative Decoding

  • Uses smaller draft model to predict tokens
  • Validates with larger target model in parallel
  • Dramatically increases generation speed (2-3x)
  • Examples:
    • Medusa (multiple decoding heads)
    • Lookahead decoding
    • Google’s Speculative Decoding

Key idea: Smaller models can accurately predict some tokens

Speculative decoding Source: Speculative Decoding paper

  1. Draft model generates K tokens
  2. Target model verifies predictions
  3. Accept correct predictions
  4. Regenerate from first incorrect token
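A toy greedy version of this loop is sketched below. The model functions are hypothetical stand-ins (any function mapping a token sequence to the next token id), and a real implementation verifies all drafted tokens in a single batched forward pass rather than one call per token:

```python
def speculative_decode(draft_model, target_model, prompt, k=4, max_tokens=16):
    """Greedy speculative decoding sketch: draft k tokens, keep the prefix
    the target model agrees with, then take one corrected target token."""
    tokens = list(prompt)
    while len(tokens) < len(prompt) + max_tokens:
        # 1. Draft model proposes k tokens autoregressively
        draft, ctx = [], list(tokens)
        for _ in range(k):
            t = draft_model(ctx)
            draft.append(t)
            ctx.append(t)
        # 2-4. Target model verifies each drafted token in turn;
        # accept matches, regenerate from the first mismatch
        for t in draft:
            expected = target_model(tokens)
            if t == expected:
                tokens.append(t)          # accepted draft token
            else:
                tokens.append(expected)   # corrected token from target
                break
        else:
            # all k accepted: the verification pass also yields one extra token
            tokens.append(target_model(tokens))
    return tokens[:len(prompt) + max_tokens]
```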

KV Cache Optimization

  • KV cache: Stored key-value pairs from attention layers
  • Grows linearly with sequence length
  • Optimization approaches:
    • Attention sinks (preserve important tokens)
    • Sliding window attention
    • KV cache pruning
    • Compression techniques

Impact: Reduces memory usage by 20-70%
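For intuition on why this matters, the cache size follows directly from the model dimensions. The figures below assume hypothetical LLaMA-2-7B-like shapes (32 layers, 32 KV heads, head dim 128) and FP16 storage:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim,
                   seq_len, batch_size, dtype_bytes=2):
    """Bytes held in the KV cache: one key and one value vector
    per token, per head, per layer (the leading 2 = K + V)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * dtype_bytes

size = kv_cache_bytes(32, 32, 128, seq_len=4096, batch_size=1)
print(f"{size / 1e9:.1f} GB")  # ~2.1 GB for a single 4k-token sequence
```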

# Conceptual sliding window attention
def sliding_window_attention(
    query, key, value, window_size=1024
):
    seq_len = query.shape[1]
    
    if seq_len <= window_size:
        # Standard attention
        return standard_attention(query, key, value)
    
    # Only attend to recent window_size tokens
    recent_keys = key[:, -window_size:, :]
    recent_values = value[:, -window_size:, :]
    
    return standard_attention(
        query, recent_keys, recent_values
    )

Pruning

  • Removes redundant weights/neurons/attention heads
  • Types:
    • Unstructured: Individual weights
    • Structured: Entire neurons/heads/layers
    • Magnitude-based: Remove smallest weights
    • Importance-based: Remove least impactful units
  • Iterative process: prune → fine-tune → repeat

Results: 30-90% parameter reduction possible

# PyTorch example: magnitude pruning
import torch.nn.utils.prune as prune

# Prune 30% of weights by magnitude
prune.l1_unstructured(
    module, 
    name="weight", 
    amount=0.3
)

# Make pruning permanent
prune.remove(module, "weight")

# Fine-tune the pruned model
train_loop(model, train_data)

Sparse Inference

  • Exploits sparsity in model architecture
  • Approaches:
    • Sparse attention patterns (BigBird, Longformer)
    • Mixture of Experts (MoE)
    • Structured sparsity (block sparsity)
    • Dynamic computation paths
  • Hardware acceleration: NVIDIA Sparse Tensor Cores

Challenge: Requires specialized hardware/libraries

# Mixture of Experts conceptual example
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparselyGatedMoE(nn.Module):
    def __init__(self, input_size, output_size, 
                 num_experts=8, k=2):
        super().__init__()
        # Initialize experts and router
        self.experts = nn.ModuleList([
            nn.Linear(input_size, output_size) 
            for _ in range(num_experts)
        ])
        self.router = nn.Linear(input_size, num_experts)
        self.k = k  # Top-k experts to use
        
    def forward(self, x):
        # Get router scores and select top-k experts per input
        router_logits = self.router(x)                      # (batch, num_experts)
        topk_logits, indices = router_logits.topk(self.k, dim=-1)
        gates = F.softmax(topk_logits, dim=-1)              # weights over chosen experts
        # Only compute the selected experts' outputs
        out = 0
        for j in range(self.k):
            expert_out = torch.stack([
                self.experts[e](xi)
                for e, xi in zip(indices[:, j].tolist(), x)
            ])
            out = out + gates[:, j:j+1] * expert_out
        return out

Tensor Parallelism & Sharding

  • Distributes model across multiple devices
  • Types:
    • Tensor Parallelism: Split individual layers
    • Pipeline Parallelism: Different layers on different devices
    • Zero Redundancy Optimizer (ZeRO): Shard optimizer states
  • Frameworks: DeepSpeed, Megatron-LM, PyTorch FSDP

Benefits: Enables inference of models too large for single GPU

Tensor Parallelism Source: Hugging Face

Example:
  • 175B parameter model
  • 16 GPUs with tensor parallelism
  • Each GPU handles ~11B parameters
  • Coordinated through collective communication
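The split can be illustrated with a column-parallel linear layer, simulated here on CPU with two weight shards. In a real deployment each shard lives on its own GPU and the partial outputs are gathered with collective ops such as all-gather:

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 4)        # full weight: (out_features, in_features)
x = torch.randn(2, 4)        # input batch

# Shard the output dimension across two "devices"
W0, W1 = W.chunk(2, dim=0)   # each shard: (4, 4)

# Each device computes its slice of the output independently
y0 = x @ W0.T
y1 = x @ W1.T

# Gather the partial results (collective communication in practice)
y_parallel = torch.cat([y0, y1], dim=-1)
y_full = x @ W.T             # matches the unsharded computation
```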

Flash Attention

  • Optimizes attention computation for transformers
  • Key benefits:
    • Reduces reads/writes to GPU high-bandwidth memory (HBM)
    • Uses tiling and recomputation to avoid materializing the full N×N attention matrix
    • Works with causal/bidirectional/cross-attention
  • FlashAttention-2: Further optimizations
  • Hardware-aware approach for GPU acceleration
  • Supports long context windows efficiently

Impact: 2-4x faster attention computation, enables longer contexts

# Using Flash Attention in PyTorch
from flash_attn import flash_attn_func

# Replace standard attention with:
attn_output = flash_attn_func(
    q,            # query
    k,            # key
    v,            # value
    dropout_p=0.0,
    causal=True   # for decoder-only models
)

# With HuggingFace Transformers (requires fp16/bf16):
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
)

Continuous Batching / Dynamic Batching

  • Traditional batching: wait for batch to fill
  • Continuous batching: process requests as they arrive
  • Techniques:
    • Iteration-level scheduling
    • Paged Attention (vLLM)
    • Continuous batching with beam search
  • Dramatically improves throughput at scale

Impact: 2-10x throughput improvement
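A toy scheduler makes the difference concrete. Each request needs `length` decode steps and the engine emits one token per active request per iteration; the numbers and policy are illustrative only (no prefill phase, no paging):

```python
from collections import deque

def static_batching_steps(lengths, max_batch=4):
    """Traditional batching: each batch runs until its longest request finishes."""
    steps = 0
    for i in range(0, len(lengths), max_batch):
        steps += max(lengths[i:i + max_batch])
    return steps

def continuous_batching_steps(lengths, max_batch=4):
    """Iteration-level scheduling: admit a waiting request as soon as a slot frees."""
    waiting, active, steps = deque(lengths), [], 0
    while waiting or active:
        while waiting and len(active) < max_batch:
            active.append(waiting.popleft())        # admit into free slots
        active = [n - 1 for n in active if n > 1]   # one decode iteration each
        steps += 1
    return steps

reqs = [5, 1, 1, 1]
print(static_batching_steps(reqs, max_batch=2))      # 6 iterations
print(continuous_batching_steps(reqs, max_batch=2))  # 5 iterations
```

The short requests no longer wait for the long one to drain, which is where the throughput gains at scale come from.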


Key metrics:
  • Time-to-first-token (TTFT)
  • Time-per-output-token (TPOT)
  • Tokens-per-second (TPS)
  • Cost per 1M tokens